/**
 * @author LKQ
 * @date 2022/2/22 21:11
 * @description Flattens a binary tree in place into a right-skewed chain in pre-order
 *     (root, left subtree, right subtree). Every node's {@code left} becomes {@code null} and its
 *     {@code right} points to its pre-order successor.
 */
public class Solution2 {
    public static void main(String[] args) {
        Solution2 solution = new Solution2();
        TreeNode t3 = new TreeNode(3, new TreeNode(4), new TreeNode(5)), t2 = new TreeNode(2, t3, new TreeNode(6)),
                t7 = new TreeNode(7, new TreeNode(8), new TreeNode(9)), t1 = new TreeNode(1, t2, t7);
        solution.flatten(t1);
    }

    /**
     * Head of the already-flattened suffix during a traversal. After {@link #flatten(TreeNode)}
     * returns, it refers to the root that was passed in, or {@code null} for an empty tree.
     */
    TreeNode last = null;

    /**
     * Flattens {@code root} in place into pre-order, linking nodes through their {@code right}
     * pointers and clearing every {@code left} pointer.
     *
     * @param root root of the tree to flatten; {@code null} is treated as an empty tree
     */
    public void flatten(TreeNode root) {
        // Reset the cursor on every call. Without this, a second call on the same instance would
        // link the new tree's last node to the root of the previously flattened tree.
        last = null;
        flattenReversed(root);
    }

    /**
     * Visits nodes in reverse pre-order (right subtree, left subtree, node). Each node is
     * prepended to the chain that is already built, so the finished chain comes out in pre-order.
     */
    private void flattenReversed(TreeNode node) {
        if (node == null) {
            return;
        }
        flattenReversed(node.right);
        flattenReversed(node.left);
        // At this point, `last` is the head of the flattened chain for everything that follows
        // `node` in pre-order.
        node.right = last;
        node.left = null;
        last = node;
    }
}
